{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n    @name: King\\n    @title: Stochastic Training for Graph Convolutional Networks\\n    @url: https://github.com/dmlc/dgl/tree/master/examples/pytorch/sampling\\n    \\n'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''\n",
    "    @name: King\n",
    "    @title: Stochastic Training for Graph Convolutional Networks\n",
    "    @url: https://github.com/dmlc/dgl/tree/master/examples/pytorch/sampling\n",
    "    \n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stochastic Training for Graph Convolutional Networks\n",
    "\n",
    "- Paper : [Control Variate](https://arxiv.org/abs/1710.10568)\n",
    "- Paper : [Skip Connection](https://arxiv.org/abs/1809.05343)\n",
    "- Author's code : [https://github.com/thu-ml/stochastic_gcn](https://github.com/thu-ml/stochastic_gcn)\n",
    "\n",
    "## Dependencies\n",
    "\n",
    "- PyTorch 0.4.1+\n",
    "- requests\n"
   ]
  },
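  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before running anything, a quick environment check can save some debugging. The cell below is a minimal sanity-check sketch: this notebook only assumes PyTorch 0.4.1+ and an older DGL release that still provides `DGLGraph` and `dgl.contrib.sampling.NeighborSampler`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check of the installed versions (a minimal sketch; the exact version strings will differ).\n",
    "import torch\n",
    "import dgl\n",
    "\n",
    "print('torch:', torch.__version__)\n",
    "print('dgl  :', dgl.__version__)"
   ]
  },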
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 一、包引入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse, time, math\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from functools import partial\n",
    "import dgl\n",
    "import dgl.function as fn\n",
    "from dgl import DGLGraph\n",
    "from dgl.data import register_data_args, load_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NodeUpdate(nn.Module):\n",
    "    def __init__(self, in_feats, out_feats, activation=None, test=False, concat=False):\n",
    "        super(NodeUpdate, self).__init__()\n",
    "        self.linear = nn.Linear(in_feats, out_feats)\n",
    "        self.activation = activation\n",
    "        self.concat = concat\n",
    "        self.test = test\n",
    "\n",
    "    def forward(self, node):\n",
    "        h = node.data['h']\n",
    "        if self.test:\n",
    "            h = h * node.data['norm']\n",
    "        h = self.linear(h)\n",
    "        # skip connection\n",
    "        if self.concat:\n",
    "            h = torch.cat((h, self.activation(h)), dim=1)\n",
    "        elif self.activation:\n",
    "            h = self.activation(h)\n",
    "        return {'activation': h}"
   ]
  },
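  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`NodeUpdate` applies a linear layer to the aggregated neighborhood feature `h`. At inference time (`test=True`) it first rescales `h` by `norm` (set to 1 / in-degree in `main`), and when `concat=True` it implements the skip connection by concatenating the pre-activation and post-activation features, doubling the feature width. The toy sketch below (shapes are assumptions, not dataset values) shows why a layer that follows a `concat=True` layer must accept `2 * n_hidden` inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration of the concat-style skip connection (assumed shapes, not dataset values):\n",
    "# concatenating h with activation(h) along dim=1 doubles the feature width, which is why\n",
    "# the output layer of the models below is built with in_feats = 2 * n_hidden.\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "h = torch.randn(4, 16)                    # 4 nodes, n_hidden = 16\n",
    "h_cat = torch.cat((h, F.relu(h)), dim=1)\n",
    "print(h_cat.shape)                        # torch.Size([4, 32])"
   ]
  },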
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GCNSampling(nn.Module):\n",
    "    def __init__(self,\n",
    "                 in_feats,\n",
    "                 n_hidden,\n",
    "                 n_classes,\n",
    "                 n_layers,\n",
    "                 activation,\n",
    "                 dropout):\n",
    "        super(GCNSampling, self).__init__()\n",
    "        self.n_layers = n_layers\n",
    "        if dropout != 0:\n",
    "            self.dropout = nn.Dropout(p=dropout)\n",
    "        else:\n",
    "            self.dropout = None\n",
    "        self.layers = nn.ModuleList()\n",
    "        # input layer\n",
    "        skip_start = (0 == n_layers-1)\n",
    "        self.layers.append(NodeUpdate(in_feats, n_hidden, activation, concat=skip_start))\n",
    "        # hidden layers\n",
    "        for i in range(1, n_layers):\n",
    "            skip_start = (i == n_layers-1)\n",
    "            self.layers.append(NodeUpdate(n_hidden, n_hidden, activation, concat=skip_start))\n",
    "        # output layer\n",
    "        self.layers.append(NodeUpdate(2*n_hidden, n_classes))\n",
    "\n",
    "    def forward(self, nf):\n",
    "        nf.layers[0].data['activation'] = nf.layers[0].data['features']\n",
    "\n",
    "        for i, layer in enumerate(self.layers):\n",
    "            h = nf.layers[i].data.pop('activation')\n",
    "            if self.dropout:\n",
    "                h = self.dropout(h)\n",
    "            nf.layers[i].data['h'] = h\n",
    "            nf.block_compute(i,\n",
    "                             fn.copy_src(src='h', out='m'),\n",
    "                             lambda node : {'h': node.mailbox['m'].mean(dim=1)},\n",
    "                             layer)\n",
    "\n",
    "        h = nf.layers[-1].data.pop('activation')\n",
    "        return h"
   ]
  },
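  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "During training, `GCNSampling.forward` runs on a `NodeFlow` produced by the neighbor sampler: `block_compute(i, ...)` copies each source node's `h` as a message (`fn.copy_src`), the user-defined reducer averages the mailbox over the sampled-neighbor dimension, and `NodeUpdate` is then applied to the result. The cell below is a toy sketch of that mean reduction with made-up shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy sketch of the mean reducer used in GCNSampling.forward (shapes are assumptions):\n",
    "# node.mailbox['m'] is laid out as (num_dst_nodes, num_sampled_neighbors, feat_dim),\n",
    "# and averaging over dim=1 collapses the sampled neighbors into one feature per node.\n",
    "import torch\n",
    "\n",
    "mailbox = torch.randn(5, 3, 16)   # 5 destination nodes, 3 sampled neighbors, 16 features\n",
    "h = mailbox.mean(dim=1)\n",
    "print(h.shape)                    # torch.Size([5, 16])"
   ]
  },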
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GCNInfer(nn.Module):\n",
    "    def __init__(self,\n",
    "                 in_feats,\n",
    "                 n_hidden,\n",
    "                 n_classes,\n",
    "                 n_layers,\n",
    "                 activation):\n",
    "        super(GCNInfer, self).__init__()\n",
    "        self.n_layers = n_layers\n",
    "        self.layers = nn.ModuleList()\n",
    "        # input layer\n",
    "        skip_start = (0 == n_layers-1)\n",
    "        self.layers.append(NodeUpdate(in_feats, n_hidden, activation, test=True, concat=skip_start))\n",
    "        # hidden layers\n",
    "        for i in range(1, n_layers):\n",
    "            skip_start = (i == n_layers-1)\n",
    "            self.layers.append(NodeUpdate(n_hidden, n_hidden, activation, test=True, concat=skip_start))\n",
    "        # output layer\n",
    "        self.layers.append(NodeUpdate(2*n_hidden, n_classes, test=True))\n",
    "\n",
    "    def forward(self, nf):\n",
    "        nf.layers[0].data['activation'] = nf.layers[0].data['features']\n",
    "\n",
    "        for i, layer in enumerate(self.layers):\n",
    "            h = nf.layers[i].data.pop('activation')\n",
    "            nf.layers[i].data['h'] = h\n",
    "            nf.block_compute(i,\n",
    "                             fn.copy_src(src='h', out='m'),\n",
    "                             fn.sum(msg='m', out='h'),\n",
    "                             layer)\n",
    "\n",
    "        h = nf.layers[-1].data.pop('activation')\n",
    "        return h"
   ]
  },
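  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`GCNInfer` differs from the training model in two ways: it aggregates with `fn.sum` over all in-neighbors (the test-time sampler in `main` uses `g.number_of_nodes()` as the fan-out, so nothing is actually sampled), and `NodeUpdate(test=True)` then rescales by `norm = 1 / in_degree`. Summing and rescaling this way reproduces the full-neighborhood mean, as the small check below illustrates (the values are made up)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small check with made-up values: summing over all in-neighbors and multiplying by\n",
    "# 1 / in_degree gives the same result as taking the mean over the full neighborhood.\n",
    "import torch\n",
    "\n",
    "messages = torch.randn(4, 8)      # messages from the 4 in-neighbors of one node\n",
    "norm = 1.0 / messages.shape[0]    # node.data['norm'] = 1 / in_degree\n",
    "assert torch.allclose(messages.sum(dim=0) * norm, messages.mean(dim=0))\n",
    "print('sum * (1 / in_degree) matches the full-neighborhood mean')"
   ]
  },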
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(args):\n",
    "    # load and preprocess dataset\n",
    "    data = load_data(args)\n",
    "\n",
    "    if args.self_loop and not args.dataset.startswith('reddit'):\n",
    "        data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])\n",
    "\n",
    "    train_nid = np.nonzero(data.train_mask)[0].astype(np.int64)\n",
    "    test_nid = np.nonzero(data.test_mask)[0].astype(np.int64)\n",
    "\n",
    "    features = torch.FloatTensor(data.features)\n",
    "    labels = torch.LongTensor(data.labels)\n",
    "    train_mask = torch.ByteTensor(data.train_mask)\n",
    "    val_mask = torch.ByteTensor(data.val_mask)\n",
    "    test_mask = torch.ByteTensor(data.test_mask)\n",
    "    in_feats = features.shape[1]\n",
    "    n_classes = data.num_labels\n",
    "    n_edges = data.graph.number_of_edges()\n",
    "\n",
    "    n_train_samples = train_mask.sum().item()\n",
    "    n_val_samples = val_mask.sum().item()\n",
    "    n_test_samples = test_mask.sum().item()\n",
    "\n",
    "    print(\"\"\"----Data statistics------'\n",
    "      #Edges %d\n",
    "      #Classes %d\n",
    "      #Train samples %d\n",
    "      #Val samples %d\n",
    "      #Test samples %d\"\"\" %\n",
    "          (n_edges, n_classes,\n",
    "              n_train_samples,\n",
    "              n_val_samples,\n",
    "              n_test_samples))\n",
    "\n",
    "    # create GCN model\n",
    "    g = DGLGraph(data.graph, readonly=True)\n",
    "    norm = 1. / g.in_degrees().float().unsqueeze(1)\n",
    "\n",
    "    if args.gpu < 0:\n",
    "        cuda = False\n",
    "    else:\n",
    "        cuda = True\n",
    "        torch.cuda.set_device(args.gpu)\n",
    "        features = features.cuda()\n",
    "        labels = labels.cuda()\n",
    "        train_mask = train_mask.cuda()\n",
    "        val_mask = val_mask.cuda()\n",
    "        test_mask = test_mask.cuda()\n",
    "        norm = norm.cuda()\n",
    "\n",
    "    g.ndata['features'] = features\n",
    "\n",
    "    num_neighbors = args.num_neighbors\n",
    "\n",
    "    g.ndata['norm'] = norm\n",
    "\n",
    "    model = GCNSampling(in_feats,\n",
    "                        args.n_hidden,\n",
    "                        n_classes,\n",
    "                        args.n_layers,\n",
    "                        F.relu,\n",
    "                        args.dropout)\n",
    "\n",
    "    if cuda:\n",
    "        model.cuda()\n",
    "\n",
    "    loss_fcn = nn.CrossEntropyLoss()\n",
    "\n",
    "    infer_model = GCNInfer(in_feats,\n",
    "                           args.n_hidden,\n",
    "                           n_classes,\n",
    "                           args.n_layers,\n",
    "                           F.relu)\n",
    "\n",
    "    if cuda:\n",
    "        infer_model.cuda()\n",
    "\n",
    "    # use optimizer\n",
    "    optimizer = torch.optim.Adam(model.parameters(),\n",
    "                                 lr=args.lr,\n",
    "                                 weight_decay=args.weight_decay)\n",
    "\n",
    "    # initialize graph\n",
    "    dur = []\n",
    "    for epoch in range(args.n_epochs):\n",
    "        for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,\n",
    "                                                       args.num_neighbors,\n",
    "                                                       neighbor_type='in',\n",
    "                                                       shuffle=True,\n",
    "                                                       num_workers=32,\n",
    "                                                       num_hops=args.n_layers+1,\n",
    "                                                       seed_nodes=train_nid):\n",
    "            nf.copy_from_parent()\n",
    "            model.train()\n",
    "            # forward\n",
    "            pred = model(nf)\n",
    "            batch_nids = nf.layer_parent_nid(-1).to(device=pred.device, dtype=torch.long)\n",
    "            batch_labels = labels[batch_nids]\n",
    "            loss = loss_fcn(pred, batch_labels)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        for infer_param, param in zip(infer_model.parameters(), model.parameters()):\n",
    "            infer_param.data.copy_(param.data)\n",
    "\n",
    "        num_acc = 0.\n",
    "\n",
    "        for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,\n",
    "                                                       g.number_of_nodes(),\n",
    "                                                       neighbor_type='in',\n",
    "                                                       num_workers=32,\n",
    "                                                       num_hops=args.n_layers+1,\n",
    "                                                       seed_nodes=test_nid):\n",
    "            nf.copy_from_parent()\n",
    "            infer_model.eval()\n",
    "            with torch.no_grad():\n",
    "                pred = infer_model(nf)\n",
    "                batch_nids = nf.layer_parent_nid(-1).to(device=pred.device, dtype=torch.long)\n",
    "                batch_labels = labels[batch_nids]\n",
    "                num_acc += (pred.argmax(dim=1) == batch_labels).sum().cpu().item()\n",
    "\n",
    "        print(\"Test Accuracy {:.4f}\". format(num_acc/n_test_samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ArgsClass():\n",
    "    def __init__(self):\n",
    "        self.dropout = 0.5\n",
    "        self.gpu = -10\n",
    "        self.lr = 3e-2\n",
    "        self.n_epochs = 200\n",
    "        self.batch_size = 1000\n",
    "        self.test_batch_size = 1000\n",
    "        self.num_neighbors = 3\n",
    "        self.n_hidden = 16\n",
    "        self.n_layers = 1\n",
    "        self.self_loop = 'store_true'\n",
    "        self.weight_decay = 5e-4\n",
    "        self.dataset = 'cora'"
   ]
  },
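  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`ArgsClass` simply hard-codes the hyperparameters that the original script reads from the command line. For reference, a hypothetical argparse equivalent is sketched below (the argument names are assumptions meant to mirror the fields above and the linked DGL example); it is not executed in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical argparse equivalent of ArgsClass (a sketch; not used by the cells below).\n",
    "import argparse\n",
    "from dgl.data import register_data_args\n",
    "\n",
    "def build_parser():\n",
    "    parser = argparse.ArgumentParser(description='GCN with neighbor sampling')\n",
    "    register_data_args(parser)  # registers dataset-related arguments such as --dataset\n",
    "    parser.add_argument('--dropout', type=float, default=0.5)\n",
    "    parser.add_argument('--gpu', type=int, default=-1)\n",
    "    parser.add_argument('--lr', type=float, default=3e-2)\n",
    "    parser.add_argument('--n-epochs', type=int, default=200)\n",
    "    parser.add_argument('--batch-size', type=int, default=1000)\n",
    "    parser.add_argument('--test-batch-size', type=int, default=1000)\n",
    "    parser.add_argument('--num-neighbors', type=int, default=3)\n",
    "    parser.add_argument('--n-hidden', type=int, default=16)\n",
    "    parser.add_argument('--n-layers', type=int, default=1)\n",
    "    parser.add_argument('--self-loop', action='store_true')\n",
    "    parser.add_argument('--weight-decay', type=float, default=5e-4)\n",
    "    return parser\n",
    "\n",
    "# args = build_parser().parse_args()  # e.g. python train.py --dataset cora --self-loop"
   ]
  },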
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<__main__.ArgsClass object at 0x00000215005A5198>\n"
     ]
    }
   ],
   "source": [
    "args = ArgsClass()\n",
    "print(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----Data statistics------'\n",
      "      #Edges 13264\n",
      "      #Classes 7\n",
      "      #Train samples 140\n",
      "      #Val samples 300\n",
      "      #Test samples 1000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\progrom\\python\\python\\python3\\lib\\site-packages\\dgl\\base.py:18: UserWarning: Initializer is not set. Use zero initializer instead. To suppress this warning, use `set_initializer` to explicitly specify which initializer to use.\n",
      "  warnings.warn(msg)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy 0.0720\n",
      "Test Accuracy 0.3960\n",
      "Test Accuracy 0.3750\n",
      "Test Accuracy 0.3940\n",
      "Test Accuracy 0.3260\n",
      "Test Accuracy 0.3090\n",
      "Test Accuracy 0.3090\n",
      "Test Accuracy 0.3090\n",
      "Test Accuracy 0.3090\n",
      "Test Accuracy 0.3090\n",
      "Test Accuracy 0.3100\n",
      "Test Accuracy 0.3220\n",
      "Test Accuracy 0.3550\n",
      "Test Accuracy 0.4020\n",
      "Test Accuracy 0.4590\n",
      "Test Accuracy 0.4840\n",
      "Test Accuracy 0.5130\n",
      "Test Accuracy 0.5630\n",
      "Test Accuracy 0.6080\n",
      "Test Accuracy 0.6480\n",
      "Test Accuracy 0.6870\n",
      "Test Accuracy 0.7210\n",
      "Test Accuracy 0.7360\n",
      "Test Accuracy 0.7480\n",
      "Test Accuracy 0.7600\n",
      "Test Accuracy 0.7590\n",
      "Test Accuracy 0.7550\n",
      "Test Accuracy 0.7620\n",
      "Test Accuracy 0.7760\n",
      "Test Accuracy 0.7900\n",
      "Test Accuracy 0.7940\n",
      "Test Accuracy 0.8000\n",
      "Test Accuracy 0.8080\n",
      "Test Accuracy 0.8160\n",
      "Test Accuracy 0.8240\n",
      "Test Accuracy 0.8290\n",
      "Test Accuracy 0.8300\n",
      "Test Accuracy 0.8270\n",
      "Test Accuracy 0.8230\n",
      "Test Accuracy 0.8260\n",
      "Test Accuracy 0.8230\n",
      "Test Accuracy 0.8320\n",
      "Test Accuracy 0.8280\n",
      "Test Accuracy 0.8300\n",
      "Test Accuracy 0.8400\n",
      "Test Accuracy 0.8410\n",
      "Test Accuracy 0.8340\n",
      "Test Accuracy 0.8220\n",
      "Test Accuracy 0.8260\n",
      "Test Accuracy 0.8270\n",
      "Test Accuracy 0.8310\n",
      "Test Accuracy 0.8290\n",
      "Test Accuracy 0.8230\n",
      "Test Accuracy 0.8310\n",
      "Test Accuracy 0.8420\n",
      "Test Accuracy 0.8390\n",
      "Test Accuracy 0.8350\n",
      "Test Accuracy 0.8340\n",
      "Test Accuracy 0.8180\n",
      "Test Accuracy 0.8270\n",
      "Test Accuracy 0.8320\n",
      "Test Accuracy 0.8340\n",
      "Test Accuracy 0.8440\n",
      "Test Accuracy 0.8440\n",
      "Test Accuracy 0.8420\n",
      "Test Accuracy 0.8410\n",
      "Test Accuracy 0.8350\n",
      "Test Accuracy 0.8310\n",
      "Test Accuracy 0.8240\n",
      "Test Accuracy 0.8240\n",
      "Test Accuracy 0.8260\n",
      "Test Accuracy 0.8290\n",
      "Test Accuracy 0.8330\n",
      "Test Accuracy 0.8370\n",
      "Test Accuracy 0.8380\n",
      "Test Accuracy 0.8420\n",
      "Test Accuracy 0.8470\n",
      "Test Accuracy 0.8470\n",
      "Test Accuracy 0.8370\n",
      "Test Accuracy 0.8320\n",
      "Test Accuracy 0.8310\n",
      "Test Accuracy 0.8340\n",
      "Test Accuracy 0.8470\n",
      "Test Accuracy 0.8480\n",
      "Test Accuracy 0.8500\n",
      "Test Accuracy 0.8500\n",
      "Test Accuracy 0.8520\n",
      "Test Accuracy 0.8480\n",
      "Test Accuracy 0.8440\n",
      "Test Accuracy 0.8350\n",
      "Test Accuracy 0.8360\n",
      "Test Accuracy 0.8480\n",
      "Test Accuracy 0.8370\n",
      "Test Accuracy 0.8440\n",
      "Test Accuracy 0.8420\n",
      "Test Accuracy 0.8410\n",
      "Test Accuracy 0.8220\n",
      "Test Accuracy 0.8160\n",
      "Test Accuracy 0.8190\n",
      "Test Accuracy 0.8310\n",
      "Test Accuracy 0.8380\n",
      "Test Accuracy 0.8420\n",
      "Test Accuracy 0.8380\n",
      "Test Accuracy 0.8410\n",
      "Test Accuracy 0.8390\n",
      "Test Accuracy 0.8300\n",
      "Test Accuracy 0.8250\n",
      "Test Accuracy 0.8230\n",
      "Test Accuracy 0.8310\n",
      "Test Accuracy 0.8380\n",
      "Test Accuracy 0.8350\n",
      "Test Accuracy 0.8400\n",
      "Test Accuracy 0.8360\n",
      "Test Accuracy 0.8410\n",
      "Test Accuracy 0.8390\n",
      "Test Accuracy 0.8280\n",
      "Test Accuracy 0.8340\n",
      "Test Accuracy 0.8280\n",
      "Test Accuracy 0.8310\n",
      "Test Accuracy 0.8360\n",
      "Test Accuracy 0.8380\n",
      "Test Accuracy 0.8380\n",
      "Test Accuracy 0.8390\n",
      "Test Accuracy 0.8380\n",
      "Test Accuracy 0.8310\n",
      "Test Accuracy 0.8390\n",
      "Test Accuracy 0.8420\n",
      "Test Accuracy 0.8340\n",
      "Test Accuracy 0.8330\n",
      "Test Accuracy 0.8300\n",
      "Test Accuracy 0.8330\n",
      "Test Accuracy 0.8400\n",
      "Test Accuracy 0.8370\n",
      "Test Accuracy 0.8430\n",
      "Test Accuracy 0.8440\n",
      "Test Accuracy 0.8380\n",
      "Test Accuracy 0.8310\n",
      "Test Accuracy 0.8350\n",
      "Test Accuracy 0.8350\n",
      "Test Accuracy 0.8340\n",
      "Test Accuracy 0.8400\n",
      "Test Accuracy 0.8410\n",
      "Test Accuracy 0.8400\n",
      "Test Accuracy 0.8450\n",
      "Test Accuracy 0.8470\n",
      "Test Accuracy 0.8350\n",
      "Test Accuracy 0.8270\n",
      "Test Accuracy 0.8260\n",
      "Test Accuracy 0.8340\n",
      "Test Accuracy 0.8300\n",
      "Test Accuracy 0.8320\n",
      "Test Accuracy 0.8380\n",
      "Test Accuracy 0.8430\n",
      "Test Accuracy 0.8550\n",
      "Test Accuracy 0.8540\n",
      "Test Accuracy 0.8520\n",
      "Test Accuracy 0.8490\n",
      "Test Accuracy 0.8380\n",
      "Test Accuracy 0.8350\n",
      "Test Accuracy 0.8310\n",
      "Test Accuracy 0.8310\n",
      "Test Accuracy 0.8290\n",
      "Test Accuracy 0.8290\n",
      "Test Accuracy 0.8420\n",
      "Test Accuracy 0.8470\n",
      "Test Accuracy 0.8400\n",
      "Test Accuracy 0.8330\n",
      "Test Accuracy 0.8360\n",
      "Test Accuracy 0.8390\n",
      "Test Accuracy 0.8510\n",
      "Test Accuracy 0.8540\n",
      "Test Accuracy 0.8570\n",
      "Test Accuracy 0.8570\n",
      "Test Accuracy 0.8480\n",
      "Test Accuracy 0.8430\n",
      "Test Accuracy 0.8380\n",
      "Test Accuracy 0.8330\n",
      "Test Accuracy 0.8320\n",
      "Test Accuracy 0.8400\n",
      "Test Accuracy 0.8410\n",
      "Test Accuracy 0.8300\n",
      "Test Accuracy 0.8370\n",
      "Test Accuracy 0.8420\n",
      "Test Accuracy 0.8410\n",
      "Test Accuracy 0.8380\n",
      "Test Accuracy 0.8370\n",
      "Test Accuracy 0.8390\n",
      "Test Accuracy 0.8510\n",
      "Test Accuracy 0.8480\n",
      "Test Accuracy 0.8370\n",
      "Test Accuracy 0.8380\n",
      "Test Accuracy 0.8260\n",
      "Test Accuracy 0.8180\n",
      "Test Accuracy 0.8050\n",
      "Test Accuracy 0.8220\n",
      "Test Accuracy 0.8310\n",
      "Test Accuracy 0.8270\n",
      "Test Accuracy 0.8310\n",
      "Test Accuracy 0.8330\n",
      "Test Accuracy 0.8390\n"
     ]
    }
   ],
   "source": [
    "main(args)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
